{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "97ce156f",
   "metadata": {},
   "source": [
    "# 5-2,模型层layers\n",
    "\n",
    "深度学习模型一般由各种模型层组合而成。\n",
    "\n",
    "torch.nn中内置了非常丰富的各种模型层。它们都属于nn.Module的子类，具备参数管理功能。\n",
    "\n",
    "例如：\n",
    "\n",
    "* nn.Linear, nn.Flatten, nn.Dropout, nn.BatchNorm2d, nn.Embedding\n",
    "\n",
    "* nn.Conv2d,nn.AvgPool2d,nn.Conv1d,nn.ConvTranspose2d\n",
    "\n",
    "* nn.GRU,nn.LSTM\n",
    "\n",
    "* nn.Transformer\n",
    "\n",
    "如果这些内置模型层不能够满足需求，我们也可以通过继承nn.Module基类构建自定义的模型层。\n",
    "\n",
    "实际上，pytorch不区分模型和模型层，都是通过继承nn.Module进行构建。\n",
    "\n",
    "因此，我们只要继承nn.Module基类并实现forward方法即可自定义模型层。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3175636b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "3bf8eb30",
   "metadata": {},
   "source": [
    "## 一，基础层"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa89b697",
   "metadata": {},
   "source": [
    "一些基础的内置模型层简单介绍如下。\n",
    "\n",
    "* nn.Linear：全连接层。参数个数 = 输入层特征数× 输出层特征数(weight)＋ 输出层特征数(bias)\n",
    "\n",
    "* nn.Embedding：嵌入层，一种比Onehot更加有效的对离散特征进行编码的方法。一般用于将输入中的单词映射为稠密向量。嵌入层的参数需要学习。\n",
    "\n",
    "* nn.Flatten：压平层，用于将多维张量样本压成一维张量样本。\n",
    "\n",
    "* nn.BatchNorm1d：一维批标准化层。通过线性变换将输入批次缩放平移到稳定的均值和标准差。可以增强模型对输入不同分布的适应性，加快模型训练速度，有轻微正则化效果。一般在激活函数之前使用。可以用afine参数设置该层是否含有可以训练的参数。\n",
    "\n",
    "* nn.BatchNorm2d：二维批标准化层。 常用于CV领域。\n",
    "\n",
    "* nn.BatchNorm3d：三维批标准化层。\n",
    "\n",
    "* nn.Dropout：一维随机丢弃层。一种正则化手段。\n",
    "\n",
    "* nn.Dropout2d：二维随机丢弃层。\n",
    "\n",
    "* nn.Dropout3d：三维随机丢弃层。\n",
    "\n",
    "* nn.Threshold：限幅层，当输入大于或小于阈值范围时，截断之。\n",
    "\n",
    "* nn.ConstantPad2d： 二维常数填充层。对二维张量样本填充常数扩展长度。\n",
    "\n",
    "* nn.ReplicationPad1d： 一维复制填充层。对一维张量样本通过复制边缘值填充扩展长度。\n",
    "\n",
    "* nn.ZeroPad2d：二维零值填充层。对二维张量样本在边缘填充0值.\n",
    "\n",
    "* nn.GroupNorm：组归一化。一种替代批归一化的方法，将通道分成若干组进行归一。不受batch大小限制。\n",
    "\n",
    "* nn.LayerNorm：层归一化。常用于NLP领域，不受序列长度不一致影响。\n",
    "\n",
    "* nn.InstanceNorm2d: 样本归一化。一般在图像风格迁移任务中效果较好。\n",
    "\n",
    "\n",
    "\n",
    "重点说说各种归一化层：\n",
    "\n",
    "$$y = \\frac{x - \\mathrm{E}[x]}{ \\sqrt{\\mathrm{Var}[x] + \\epsilon}} * \\gamma + \\beta$$\n",
    "\n",
    "\n",
    "\n",
    "* 结构化数据的BatchNorm1D归一化 【结构化数据的主要区分度来自每个样本特征在全体样本中的排序，将全部样本的某个特征都进行相同的放大缩小平移操作，样本间的区分度基本保持不变，所以结构化数据可以做BatchNorm，但LayerNorm会打乱全体样本根据某个特征的排序关系，引起区分度下降】\n",
    "\n",
    "\n",
    "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h5mbd2ill5j20a808z0ta.jpg)\n",
    "\n",
    "\n",
    "* 图片数据的各种归一化(一般常用BatchNorm2D)【图片数据的主要区分度来自图片中的纹理结构，所以图片数据的归一化一定要在图片的宽高方向上操作以保持纹理结构，此外在Batch维度上操作还能够引入少许的正则化，对提升精度有进一步的帮助。】\n",
    "\n",
    "\n",
    "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h5m92dtnd0j20tn07ztab.jpg)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "* 文本数据的LayerNorm归一化 【文本数据的主要区分度来自于词向量(Embedding向量)的方向，所以文本数据的归一化一定要在 特征(通道)维度上操作 以保持 词向量方向不变。此外文本数据还有一个重要的特点是不同样本的序列长度往往不一样，所以不可以在Sequence和Batch维度上做归一化，否则将不可避免地让padding位置对应的向量变成非零向量】\n",
    "\n",
    "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h5m903lv0nj20jc0iawfx.jpg)\n",
    "\n",
    "\n",
    "\n",
    "* 此外，有论文提出了一种可自适应学习的归一化：SwitchableNorm，可应用于各种场景且有一定的效果提升。【SwitchableNorm是将BN、LN、IN结合，赋予权重，让网络自己去学习归一化层应该使用什么方法。】\n",
    "\n",
    "参考论文：https://arxiv.org/pdf/1806.10779.pdf\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "对BatchNorm需要注意的几点：\n",
    "\n",
    "(1)BatchNorm放在激活函数前还是激活函数后？\n",
    "\n",
    "原始论文认为将BatchNorm放在激活函数前效果较好，后面的研究一般认为将BatchNorm放在激活函数之后更好。\n",
    "\n",
    "(2)BatchNorm在训练过程和推理过程的逻辑是否一样？\n",
    "\n",
    "不一样！训练过程BatchNorm的均值和方差和根据mini-batch中的数据估计的，而推理过程中BatchNorm的均值和方差是用的训练过程中的全体样本估计的。因此预测过程是稳定的，相同的样本不会因为所在批次的差异得到不同的结果，但训练过程中则会受到批次中其他样本的影响所以有正则化效果。\n",
    "\n",
    "(3)BatchNorm的精度效果与batch_size大小有何关系? \n",
    "\n",
    "如果受到GPU内存限制，不得不使用很小的batch_size，训练阶段时使用的mini-batch上的均值和方差的估计和预测阶段时使用的全体样本上的均值和方差的估计差异可能会较大，效果会变差。这时候，可以尝试LayerNorm或者GroupNorm等归一化方法。\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8cf76da1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "channel mean: 2.9802322387695312e-08\n",
      "channel std: 1.0000009536743164\n"
     ]
    }
   ],
   "source": [
    "import torch \n",
    "from torch import nn \n",
    "\n",
    "batch_size, channel, height, width = 32, 16, 128, 128\n",
    "\n",
    "tensor = torch.arange(0,32*16*128*128).view(32,16,128,128).float() \n",
    "\n",
    "bn = nn.BatchNorm2d(num_features=channel,affine=False)\n",
    "bn_out = bn(tensor)\n",
    "\n",
    "\n",
    "channel_mean = torch.mean(bn_out[:,0,:,:]) \n",
    "channel_std = torch.std(bn_out[:,0,:,:])\n",
    "print(\"channel mean:\",channel_mean.item())\n",
    "print(\"channel std:\",channel_std.item())\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a2f406b8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "token_mean: -2.7939677238464355e-09\n",
      "token_mean: 1.0002442598342896\n"
     ]
    }
   ],
   "source": [
    "import torch \n",
    "from torch import nn \n",
    "\n",
    "batch_size, sequence, features = 32, 100, 2048\n",
    "tensor = torch.arange(0,32*100*2048).view(32,100,2048).float() \n",
    "\n",
    "ln = nn.LayerNorm(normalized_shape=[features],\n",
    "                  elementwise_affine = False)\n",
    "\n",
    "ln_out = ln(tensor)\n",
    "\n",
    "token_mean = torch.mean(ln_out[0,0,:]) \n",
    "token_std = torch.std(ln_out[0,0,:])\n",
    "print(\"token_mean:\",token_mean.item())\n",
    "print(\"token_mean:\",token_std.item())\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa1b19de",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "6d05ae41",
   "metadata": {},
   "source": [
    "## 二，卷积网络相关层\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd45681a",
   "metadata": {},
   "source": [
    "一些与卷积相关的内置层介绍如下\n",
    "\n",
    "* nn.Conv1d：普通一维卷积，常用于文本。参数个数 = 输入通道数×卷积核尺寸(如3)×卷积核个数 + 卷积核尺寸(如3）\n",
    "  \n",
    "* nn.Conv2d：普通二维卷积，常用于图像。参数个数 = 输入通道数×卷积核尺寸(如3乘3)×卷积核个数 + 卷积核尺寸(如3乘3)。) 通过调整dilation参数大于1，可以变成空洞卷积，增加感受野。 通过调整groups参数不为1，可以变成分组卷积。分组卷积中每个卷积核仅对其对应的一个分组进行操作。 当groups参数数量等于输入通道数时，相当于tensorflow中的二维深度卷积层tf.keras.layers.DepthwiseConv2D。 利用分组卷积和1乘1卷积的组合操作，可以构造相当于Keras中的二维深度可分离卷积层tf.keras.layers.SeparableConv2D。\n",
    "\n",
    "* nn.Conv3d：普通三维卷积，常用于视频。参数个数 = 输入通道数×卷积核尺寸(如3乘3乘3)×卷积核个数 + 卷积核尺寸(如3乘3乘3) 。\n",
    "\n",
    "* nn.MaxPool1d: 一维最大池化。\n",
    "\n",
    "* nn.MaxPool2d：二维最大池化。一种下采样方式。没有需要训练的参数。\n",
    "\n",
    "* nn.MaxPool3d：三维最大池化。\n",
    "\n",
    "* nn.AdaptiveMaxPool2d：二维自适应最大池化。无论输入图像的尺寸如何变化，输出的图像尺寸是固定的。\n",
    "  该函数的实现原理，大概是通过输入图像的尺寸和要得到的输出图像的尺寸来反向推算池化算子的padding,stride等参数。\n",
    "  \n",
    "* nn.FractionalMaxPool2d：二维分数最大池化。普通最大池化通常输入尺寸是输出的整数倍。而分数最大池化则可以不必是整数。分数最大池化使用了一些随机采样策略，有一定的正则效果，可以用它来代替普通最大池化和Dropout层。\n",
    "\n",
    "* nn.AvgPool2d：二维平均池化。\n",
    "\n",
    "* nn.AdaptiveAvgPool2d：二维自适应平均池化。无论输入的维度如何变化，输出的维度是固定的。\n",
    "\n",
    "* nn.ConvTranspose2d：二维卷积转置层，俗称反卷积层。并非卷积的逆操作，但在卷积核相同的情况下，当其输入尺寸是卷积操作输出尺寸的情况下，卷积转置的输出尺寸恰好是卷积操作的输入尺寸。在语义分割中可用于上采样。\n",
    "\n",
    "* nn.Upsample：上采样层，操作效果和池化相反。可以通过mode参数控制上采样策略为\"nearest\"最邻近策略或\"linear\"线性插值策略。\n",
    "\n",
    "* nn.Unfold：滑动窗口提取层。其参数和卷积操作nn.Conv2d相同。实际上，卷积操作可以等价于nn.Unfold和nn.Linear以及nn.Fold的一个组合。\n",
    "  其中nn.Unfold操作可以从输入中提取各个滑动窗口的数值矩阵，并将其压平成一维。利用nn.Linear将nn.Unfold的输出和卷积核做乘法后，再使用\n",
    "  nn.Fold操作将结果转换成输出图片形状。\n",
    "\n",
    "* nn.Fold：逆滑动窗口提取层。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb129af1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "065c1f9b",
   "metadata": {},
   "source": [
    "重点说说各种常用的卷积层和上采样层：\n",
    "\n",
    "* 普通卷积【普通卷积的操作分成3个维度，在空间维度(H和W维度)是共享卷积核权重滑窗相乘求和(融合空间信息)，在输入通道维度是每一个通道使用不同的卷积核参数并对输入通道维度求和(融合通道信息)，在输出通道维度操作方式是并行堆叠(多种)，有多少个卷积核就有多少个输出通道】\n",
    "\n",
    "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h5nhe0lsutg20az0aln03.gif)\n",
    "\n",
    "\n",
    "\n",
    "* 空洞卷积【和普通卷积相比，空洞卷积可以在保持较小参数规模的条件下增大感受野，常用于图像分割领域。其缺点是可能产生网格效应，即有些像素被空洞漏过无法利用到，可以通过使用不同膨胀因子的空洞卷积的组合来克服该问题，参考文章：https://developer.orbbec.com.cn/v/blog_detail/892 】 \n",
    "\n",
    "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h5nhe0y7x1g20az0al0vu.gif)\n",
    "\n",
    "\n",
    "* 分组卷积 【和普通卷积相比，分组卷积将输入通道分成g组，卷积核也分成对应的g组，每个卷积核只在其对应的那组输入通道上做卷积，最后将g组结果堆叠拼接。由于每个卷积核只需要在全部输入通道的1/g个通道上做卷积，参数量降低为普通卷积的1/g。分组卷积要求输入通道和输出通道数都是g的整数倍。参考文章：https://zhuanlan.zhihu.com/p/65377955 】\n",
    "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h5npy1zyalj20ie0erwf8.jpg)\n",
    "\n",
    "\n",
    "\n",
    "* 深度可分离卷积【深度可分离卷积的思想是先用g=m(输入通道数)的分组卷积逐通道作用融合空间信息，再用n(输出通道数)个1乘1卷积融合通道信息。 其参数量为 (m×k×k)+ n×m, 相比普通卷积的参数量 m×n×k×k 显著减小 】\n",
    "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h5npbiuvzvj20uq0e7dge.jpg)\n",
    "\n",
    "\n",
    "* 转置卷积 【一般的卷积操作后会让特征图尺寸变小，但转置卷积(也被称为反卷积)可以实现相反的效果，即放大特征图尺寸。对两种方式理解转置卷积，第一种方式是转置卷积是一种特殊的卷积，通过设置合适的padding的大小来恢复特征图尺寸。第二种理解基于卷积运算的矩阵乘法表示方法，转置卷积相当于将卷积核对应的表示矩阵做转置，然后乘上输出特征图压平的一维向量，即可恢复原始输入特征图的大小。\n",
    "参考文章：https://zhuanlan.zhihu.com/p/115070523】\n",
    "\n",
    "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h5ns98iiamj20v70u075e.jpg)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "* 上采样层 【除了使用转置卷积进行上采样外，在图像分割领域更多的时候一般是使用双线性插值的方式进行上采样，该方法没有需要学习的参数，通常效果也更好，除了双线性插值之外，还可以使用最邻近插值的方式进行上采样，但使用较少。】\n",
    "\n",
    "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h5nsi5pt4eg20na0co74k.gif)\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "74ea2f3c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--inputs--\n",
      "tensor([[[[ 0.,  1.,  2.,  3.,  4.],\n",
      "          [ 5.,  6.,  7.,  8.,  9.],\n",
      "          [10., 11., 12., 13., 14.],\n",
      "          [15., 16., 17., 18., 19.],\n",
      "          [20., 21., 22., 23., 24.]]]])\n",
      "--filters--\n",
      "tensor([[[[1., 1.],\n",
      "          [1., 1.]]]])\n",
      "--outputs--\n",
      "tensor([[[[12., 16., 20., 24.],\n",
      "          [32., 36., 40., 44.],\n",
      "          [52., 56., 60., 64.],\n",
      "          [72., 76., 80., 84.]]]]) \n",
      "\n",
      "--outputs(stride=2)--\n",
      "tensor([[[[12., 20.],\n",
      "          [52., 60.]]]]) \n",
      "\n",
      "--outputs(padding=1)--\n",
      "tensor([[[[ 0.,  1.,  3.,  5.,  7.,  4.],\n",
      "          [ 5., 12., 16., 20., 24., 13.],\n",
      "          [15., 32., 36., 40., 44., 23.],\n",
      "          [25., 52., 56., 60., 64., 33.],\n",
      "          [35., 72., 76., 80., 84., 43.],\n",
      "          [20., 41., 43., 45., 47., 24.]]]]) \n",
      "\n",
      "--outputs(dilation=2)--\n",
      "tensor([[[[24., 28., 32.],\n",
      "          [44., 48., 52.],\n",
      "          [64., 68., 72.]]]]) \n",
      "\n"
     ]
    }
   ],
   "source": [
    "import torch \n",
    "from torch import nn \n",
    "import torch.nn.functional as F \n",
    "\n",
    "# 卷积输出尺寸计算公式 o = (i + 2*p -k')//s  + 1 \n",
    "# 对空洞卷积 k' = d(k-1) + 1\n",
    "# o是输出尺寸，i 是输入尺寸，p是 padding大小， k 是卷积核尺寸， s是stride步长, d是dilation空洞参数\n",
    "\n",
    "inputs = torch.arange(0,25).view(1,1,5,5).float() # i= 5\n",
    "filters = torch.tensor([[[[1.0,1],[1,1]]]]) # k = 2\n",
    "\n",
    "outputs = F.conv2d(inputs, filters) # o = (5+2*0-2)//1+1 = 4\n",
    "outputs_s2 = F.conv2d(inputs, filters, stride=2)  #o = (5+2*0-2)//2+1 = 2\n",
    "outputs_p1 = F.conv2d(inputs, filters, padding=1) #o = (5+2*1-2)//1+1 = 6\n",
    "outputs_d2 = F.conv2d(inputs,filters, dilation=2) #o = (5+2*0-(2(2-1)+1))//1+1 = 3\n",
    "\n",
    "print(\"--inputs--\")\n",
    "print(inputs)\n",
    "print(\"--filters--\")\n",
    "print(filters)\n",
    "\n",
    "print(\"--outputs--\")\n",
    "print(outputs,\"\\n\")\n",
    "\n",
    "print(\"--outputs(stride=2)--\")\n",
    "print(outputs_s2,\"\\n\")\n",
    "\n",
    "print(\"--outputs(padding=1)--\")\n",
    "print(outputs_p1,\"\\n\")\n",
    "\n",
    "print(\"--outputs(dilation=2)--\")\n",
    "print(outputs_d2,\"\\n\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "27bdd558",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "features.shape: torch.Size([8, 64, 128, 128])\n",
      "\n",
      "\n",
      "--conv--\n",
      "conv_out.shape: torch.Size([8, 32, 126, 126])\n",
      "conv.weight.shape: torch.Size([32, 64, 3, 3])\n",
      "\n",
      "\n",
      "--group conv--\n",
      "group_out.shape: torch.Size([8, 32, 126, 126])\n",
      "conv_group.weight.shape: torch.Size([32, 8, 3, 3])\n",
      "\n",
      "\n",
      "--separable conv--\n",
      "separable_out.shape: torch.Size([8, 32, 126, 126])\n",
      "depth_conv.weight.shape: torch.Size([64, 1, 3, 3])\n",
      "oneone_conv.weight.shape: torch.Size([32, 64, 1, 1])\n",
      "\n",
      "\n",
      "--conv transpose--\n",
      "features_like.shape: torch.Size([8, 64, 128, 128])\n",
      "conv_t.weight.shape: torch.Size([32, 64, 3, 3])\n"
     ]
    }
   ],
   "source": [
    "import torch \n",
    "from torch import nn \n",
    "\n",
    "features = torch.randn(8,64,128,128)\n",
    "print(\"features.shape:\",features.shape)\n",
    "print(\"\\n\")\n",
    "\n",
    "#普通卷积\n",
    "print(\"--conv--\")\n",
    "conv = nn.Conv2d(in_channels=64,out_channels=32,kernel_size=3)\n",
    "conv_out = conv(features)\n",
    "print(\"conv_out.shape:\",conv_out.shape) \n",
    "print(\"conv.weight.shape:\",conv.weight.shape)\n",
    "print(\"\\n\")\n",
    "\n",
    "#分组卷积\n",
    "print(\"--group conv--\")\n",
    "conv_group = nn.Conv2d(in_channels=64,out_channels=32,kernel_size=3,groups=8)\n",
    "group_out = conv_group(features)\n",
    "print(\"group_out.shape:\",group_out.shape) \n",
    "print(\"conv_group.weight.shape:\",conv_group.weight.shape)\n",
    "print(\"\\n\")\n",
    "\n",
    "#深度可分离卷积\n",
    "print(\"--separable conv--\")\n",
    "depth_conv = nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,groups=64)\n",
    "oneone_conv = nn.Conv2d(in_channels=64,out_channels=32,kernel_size=1)\n",
    "separable_conv = nn.Sequential(depth_conv,oneone_conv)\n",
    "separable_out = separable_conv(features)\n",
    "print(\"separable_out.shape:\",separable_out.shape) \n",
    "print(\"depth_conv.weight.shape:\",depth_conv.weight.shape)\n",
    "print(\"oneone_conv.weight.shape:\",oneone_conv.weight.shape)\n",
    "print(\"\\n\")\n",
    "\n",
    "#转置卷积\n",
    "print(\"--conv transpose--\")\n",
    "conv_t = nn.ConvTranspose2d(in_channels=32,out_channels=64,kernel_size=3)\n",
    "features_like = conv_t(conv_out)\n",
    "print(\"features_like.shape:\",features_like.shape)\n",
    "print(\"conv_t.weight.shape:\",conv_t.weight.shape)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d22c1c29",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "inputs:\n",
      "tensor([[[[1., 2.],\n",
      "          [3., 4.]]]])\n",
      "\n",
      "\n",
      "nearest(inputs)：\n",
      "tensor([[[[1., 1., 2., 2.],\n",
      "          [1., 1., 2., 2.],\n",
      "          [3., 3., 4., 4.],\n",
      "          [3., 3., 4., 4.]]]])\n",
      "\n",
      "\n",
      "bilinear(inputs)：\n",
      "tensor([[[[1.0000, 1.3333, 1.6667, 2.0000],\n",
      "          [1.6667, 2.0000, 2.3333, 2.6667],\n",
      "          [2.3333, 2.6667, 3.0000, 3.3333],\n",
      "          [3.0000, 3.3333, 3.6667, 4.0000]]]])\n"
     ]
    }
   ],
   "source": [
    "import torch \n",
    "from torch import nn \n",
    "\n",
    "inputs = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)\n",
    "print(\"inputs:\")\n",
    "print(inputs)\n",
    "print(\"\\n\")\n",
    "\n",
    "nearest = nn.Upsample(scale_factor=2, mode='nearest')\n",
    "bilinear = nn.Upsample(scale_factor=2,mode=\"bilinear\",align_corners=True)\n",
    "\n",
    "print(\"nearest(inputs)：\")\n",
    "print(nearest(inputs))\n",
    "print(\"\\n\")\n",
    "print(\"bilinear(inputs)：\")\n",
    "print(bilinear(inputs)) \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3b8b35c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "094132b2",
   "metadata": {},
   "source": [
    "## 三，循环网络相关层"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab158895",
   "metadata": {},
   "source": [
    "\n",
    "\n",
    "* nn.LSTM：长短记忆循环网络层【支持多层】。最普遍使用的循环网络层。具有携带轨道，遗忘门，更新门，输出门。可以较为有效地缓解梯度消失问题，从而能够适用长期依赖问题。设置bidirectional = True时可以得到双向LSTM。需要注意的时，默认的输入和输出形状是(seq,batch,feature), 如果需要将batch维度放在第0维，则要设置batch_first参数设置为True。\n",
    "\n",
    "* nn.GRU：门控循环网络层【支持多层】。LSTM的低配版，不具有携带轨道，参数数量少于LSTM，训练速度更快。\n",
    "\n",
    "* nn.RNN：简单循环网络层【支持多层】。容易存在梯度消失，不能够适用长期依赖问题。一般较少使用。\n",
    "\n",
    "* nn.LSTMCell：长短记忆循环网络单元。和nn.LSTM在整个序列上迭代相比，它仅在序列上迭代一步。一般较少使用。\n",
    "\n",
    "* nn.GRUCell：门控循环网络单元。和nn.GRU在整个序列上迭代相比，它仅在序列上迭代一步。一般较少使用。\n",
    "\n",
    "* nn.RNNCell：简单循环网络单元。和nn.RNN在整个序列上迭代相比，它仅在序列上迭代一步。一般较少使用。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e4ce392",
   "metadata": {},
   "source": [
    "重点介绍一下LSTM和GRU \n",
    "\n",
    "\n",
    "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h5rzqa54z6j20u00j30u3.jpg)\n",
    "\n",
    "一般地，各种RNN序列模型层(RNN,GRU,LSTM等)可以用函数表示如下:\n",
    "\n",
    "$$h_t = f(h_{t-1},x_t)$$\n",
    "\n",
    "这个公式的含义是：t时刻循环神经网络的输出向量$h_t$由t-1时刻的输出向量$h_{t-1}$和t时刻的输入$i_t$变换而来。\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ed59b95",
   "metadata": {},
   "source": [
    "* LSTM 结构解析 \n",
    "\n",
    "参考文章：《人人都能看懂的LSTM》https://zhuanlan.zhihu.com/p/32085405\n",
    "\n",
    "LSTM通过引入了三个门来控制信息的传递，分别是遗忘门，输入门 和输出门 。三个门的作用为：\n",
    "\n",
    "（1）遗忘门: 遗忘门$f_t$控制上一时刻的内部状态  需要遗忘多少信息；\n",
    "\n",
    "（2）输入门: 输入门$i_t$控制当前时刻的候选状态  有多少信息需要保存；\n",
    "\n",
    "（3）输出门: 输出门$o_t$控制当前时刻的内部状态  有多少信息需要输出给外部状态  ；\n",
    "\n",
    "\n",
    "$$\n",
    "\\begin{align}\n",
    "i_{t}=\\sigma\\left(W_{i} x_{t}+U_{i} h_{t-1}+b_{i}\\right) \\tag{1} \\\\\n",
    "f_{t}=\\sigma\\left(W_{f} x_{t}+U_{f} h_{t-1}+b_{f}\\right) \\tag{2} \\\\\n",
    "o_{t}=\\sigma\\left(W_{o} x_{t}+U_{o} h_{t-1}+b_{o}\\right) \\tag{3} \\\\\n",
    "\\tilde{c}_{t}=\\tanh \\left(W_{c} x_{t}+U_{c} h_{t-1}+b_{c}\\right) \\tag{4} \\\\\n",
    "c_{t}=f_{t} \\odot c_{t-1}+i_{t} \\odot \\tilde{c}_{t} \\tag{5} \\\\\n",
    "h_{t}=o_{t} \\odot \\tanh \\left(c_{t}\\right) \\tag{6}\n",
    "\\end{align}\n",
    "$$\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d52277c4",
   "metadata": {},
   "source": [
    "* GRU 结构解析\n",
    "\n",
    "参考文章：《人人都能看懂的GRU》https://zhuanlan.zhihu.com/p/32481747\n",
    "\n",
    "GRU的结构比LSTM更为简单一些，GRU只有两个门，更新门和重置门  。\n",
    "\n",
    "（1）更新门：更新门用于控制每一步$h_t$被更新的比例，更新门越大，$h_t$更新幅度越大。\n",
    "\n",
    "（2）重置门：重置门用于控制更新候选向量$\\tilde{h}_{t}$中前一步的状态$h_{t-1}$被重新放入的比例，重置门越大，更新候选向量中$h_{t-1}$被重新放进来的比例越大。\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "公式中的小圈表示哈达玛积，也就是两个向量逐位相乘。\n",
    "\n",
    "其中(1)式和(2)式计算的是更新门$u_t$和重置门$r_t$，是两个长度和$h_t$相同的向量。\n",
    "\n",
    "\n",
    "注意到(4)式 实际上和ResNet的残差结构是相似的，都是 f(x) = x + g(x) 的形式，可以有效地防止长序列学习反向传播过程中梯度消失问题。\n",
    "\n",
    "\n",
    "\n",
    "$$\n",
    "\\begin{align}\n",
    "z_{t}=\\sigma\\left(W_{z} x_{t}+U_{z} h_{t-1}+b_{z}\\right)\\tag{1} \\\\\n",
    "r_{t}=\\sigma\\left(W_{r} x_{t}+U_{r} h_{t-1}+b_{r}\\right) \\tag{2}\\\\\n",
    "\\tilde{h}_{t}=\\tanh \\left(W_{h} x_{t}+U_{h}\\left(r_{t} \\odot h_{t-1}\\right)+b_{h}\\right) \\tag{3}\\\\\n",
    "h_{t}= h_{t-1} - z_{t}\\odot h_{t-1}  + z_{t} \\odot  \\tilde{h}_{t} \\tag{4}\n",
    "\\end{align}\n",
    "$$\n",
    "GRU的参数数量为LSTM的3/4.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7ce71a78",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--GRU--\n",
      "gru_output.shape: torch.Size([8, 200, 32])\n",
      "gru_hn.shape: torch.Size([1, 8, 32])\n",
      "\n",
      "\n",
      "--LSTM--\n",
      "lstm_output.shape: torch.Size([8, 200, 32])\n",
      "lstm_hn.shape: torch.Size([1, 8, 32])\n",
      "lstm_cn.shape: torch.Size([1, 8, 32])\n",
      "--------------------------------------------------------------------------\n",
      "Layer (type)                            Output Shape              Param #\n",
      "==========================================================================\n",
      "GRU-1                                  [-1, 200, 32]                9,408\n",
      "==========================================================================\n",
      "Total params: 9,408\n",
      "Trainable params: 9,408\n",
      "Non-trainable params: 0\n",
      "--------------------------------------------------------------------------\n",
      "Input size (MB): 0.000076\n",
      "Forward/backward pass size (MB): 0.048828\n",
      "Params size (MB): 0.035889\n",
      "Estimated Total Size (MB): 0.084793\n",
      "--------------------------------------------------------------------------\n",
      "--------------------------------------------------------------------------\n",
      "Layer (type)                            Output Shape              Param #\n",
      "==========================================================================\n",
      "LSTM-1                                 [-1, 200, 32]               12,544\n",
      "==========================================================================\n",
      "Total params: 12,544\n",
      "Trainable params: 12,544\n",
      "Non-trainable params: 0\n",
      "--------------------------------------------------------------------------\n",
      "Input size (MB): 0.000076\n",
      "Forward/backward pass size (MB): 0.048828\n",
      "Params size (MB): 0.047852\n",
      "Estimated Total Size (MB): 0.096756\n",
      "--------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "import torch \n",
    "from torch import nn \n",
    "\n",
    "inputs = torch.randn(8,200,64) #batch_size, seq_length, features\n",
    "\n",
    "gru = nn.GRU(input_size=64,hidden_size=32,num_layers=1,batch_first=True)\n",
    "gru_output,gru_hn = gru(inputs)\n",
    "print(\"--GRU--\")\n",
    "print(\"gru_output.shape:\",gru_output.shape)\n",
    "print(\"gru_hn.shape:\",gru_hn.shape)\n",
    "print(\"\\n\")\n",
    "\n",
    "\n",
    "print(\"--LSTM--\")\n",
    "lstm = nn.LSTM(input_size=64,hidden_size=32,num_layers=1,batch_first=True)\n",
    "lstm_output,(lstm_hn,lstm_cn) = lstm(inputs)\n",
    "print(\"lstm_output.shape:\",lstm_output.shape)\n",
    "print(\"lstm_hn.shape:\",lstm_hn.shape)\n",
    "print(\"lstm_cn.shape:\",lstm_cn.shape)\n",
    "\n",
    "\n",
    "from torchkeras import summary\n",
    "summary(gru,input_data=inputs);\n",
    "summary(lstm,input_data=inputs);\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ea31566a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.75"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "9408/12544 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76a3282c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "9c1253cc",
   "metadata": {},
   "source": [
    "## 四，Transformer相关层"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40cae9f0",
   "metadata": {},
   "source": [
    "* nn.Transformer：Transformer网络结构。Transformer网络结构是替代循环网络的一种结构，解决了循环网络难以并行，难以捕捉长期依赖的缺陷。它是目前NLP任务的主流模型的主要构成部分。\n",
    "\n",
    "* nn.TransformerEncoder：Transformer编码器结构。由多个 nn.TransformerEncoderLayer编码器层组成。\n",
    "\n",
    "* nn.TransformerDecoder：Transformer解码器结构。由多个 nn.TransformerDecoderLayer解码器层组成。\n",
    "\n",
    "* nn.TransformerEncoderLayer：Transformer的编码器层。主要由Multi-Head self-Attention, Feed-Forward前馈网络, LayerNorm归一化层, 以及残差连接层组成。\n",
    "\n",
    "* nn.TransformerDecoderLayer：Transformer的解码器层。主要由Masked Multi-Head self-Attention, Multi-Head cross-Attention, Feed-Forward前馈网络, LayerNorm归一化层, 以及残差连接层组成。\n",
    "\n",
    "* nn.MultiheadAttention：多头注意力层。用于在序列方向上融合特征。使用的是Scaled Dot Production Attention，并引入了多个注意力头。\n",
    "\n",
    "\n",
    "$$\\operatorname{Attention}(Q, K, V)=\\operatorname{softmax}\\left(\\frac{Q K^{T}}{\\sqrt{d_{k}}}\\right) V$$ \n",
    "\\begin{aligned}\n",
    "\\operatorname{MultiHead}(Q, K, V) &=\\operatorname{Concat}\\left(\\operatorname{head}_{1}, \\ldots, \\text { head }_{\\mathrm{h}}\\right) W^{O} \\\\\n",
    "\\text { where }\\, head_{i} &=\\operatorname{Attention}\\left(Q W_{i}^{Q}, K W_{i}^{K}, V W_{i}^{V}\\right)\n",
    "\\end{aligned}\n",
    "\n",
    "\n",
    "![](./data/5-2-Transformer结构.jpg)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57f81dc5",
   "metadata": {},
   "source": [
    "参考阅读材料： \n",
    "\n",
    "Transformer知乎原理讲解：https://zhuanlan.zhihu.com/p/48508221\n",
    "\n",
    "Transformer哈佛博客代码讲解：http://nlp.seas.harvard.edu/annotated-transformer/ \n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fed0f30d",
   "metadata": {},
   "source": [
    "主要对Transformer的要点问题做一些梳理：\n",
    "\n",
    "\n",
    "1，Transformer是如何解决长距离依赖的问题的？\n",
    "\n",
    "Transformer是通过引入Scale-Dot-Product注意力机制来融合序列上不同位置的信息，从而解决长距离依赖问题。以文本数据为例，在循环神经网络LSTM结构中，输入序列上相距很远的两个单词无法直接发生交互，只能通过隐藏层输出或者细胞状态按照时间步骤一个一个向后进行传递。对于两个在序列上相距非常远的单词，中间经过的其它单词让隐藏层输出和细胞状态混入了太多的信息，很难有效地捕捉这种长距离依赖特征。但是在Scale-Dot-Product注意力机制中，序列上的每个单词都会和其它所有单词做一次点积计算注意力得分，这种注意力机制中单词之间的交互是强制的不受距离影响的，所以可以解决长距离依赖问题。\n",
    "\n",
    "\n",
    "2，Transformer在训练和测试阶段可以在时间(序列)维度上进行并行吗？\n",
    "\n",
    "在训练阶段，Encoder和Decoder在时间(序列)维度都是并行的，在测试阶段，Encoder在序列维度是并行的，Decoder是串行的。\n",
    "\n",
    "首先，Encoder部分在训练阶段和预测阶段都可以并行比较好理解，无论在训练还是预测阶段，它干的事情都是把已知的完整输入编码成memory，在序列维度可以并行。\n",
    "\n",
    "对于Decoder部分有些微妙。在预测阶段Decoder肯定是不能并行的，因为Decoder实际上是一个自回归，它前面k-1位置的输出会变成第k位的输入的。前面没有计算完，后面是拿不到输入的，肯定不可以并行。那么训练阶段能否并行呢？虽然训练阶段知道了全部的解码结果，但是训练阶段要和预测阶段一致啊，前面的解码输出不能受到后面解码结果的影响啊。但Transformer通过在Decoder中巧妙地引入Mask技巧，使得在用Attention机制做序列特征融合的时候，每个单词对位于它之后的单词的注意力得分都为0，这样就保证了前面的解码输出不会受到后面解码结果的影响，因此Decoder在训练阶段可以在序列维度做并行。\n",
    "\n",
    "\n",
    "3，Scaled-Dot Product Attention为什么要除以$\\sqrt{d_k}$?\n",
    "\n",
    "为了避免$d_k$变得很大时softmax函数的梯度趋于0。假设Q和K中的取出的两个向量$q$和$k$的每个元素值都是正态随机分布，数学上可以证明两个独立的正态随机变量的积依然是一个正态随机变量，那么两个向量做点积，会得到$d_k$个正态随机变量的和，数学上$d_k$个正态随机变量的和依然是一个正态随机变量，其方差是原来的$d_k$倍，标准差是原来的$\\sqrt{d_k}$倍。如果不做scale, 当$d_k$很大时，求得的$QK^T$元素的绝对值容易很大，导致落在softmax的极端区域(趋于0或者1)，极端区域softmax函数的梯度值趋于0，不利于模型学习。除以$\\sqrt{d_k}$，恰好做了归一，不受$d_k$变化影响。\n",
    "\n",
    "\n",
    "4，MultiHeadAttention的参数数量和head数量有何关系?\n",
    "\n",
    "MultiHeadAttention的参数数量和head数量无关。多头注意力的参数来自对QKV的三个变换矩阵以及多头结果concat后的输出变换矩阵。假设嵌入向量的长度是d_model, 一共有h个head. 对每个head，$W_{i}^{Q},W_{i}^{K},W_{i}^{V}$ 这三个变换矩阵的尺寸都是 d_model×(d_model/h)，所以h个head总的参数数量就是3×d_model×(d_model/h)×h = 3×d_model×d_model。它们的输出向量长度都变成 d_model/h，经过attention作用后向量长度保持，h个head的输出拼接到一起后向量长度还是d_model，所以最后输出变换矩阵的尺寸是d_model×d_model。因此，MultiHeadAttention的参数数量为 4×d_model×d_model，和head数量无关。\n",
    "\n",
    "\n",
    "5，Transformer有什么缺点？\n",
    "\n",
    "Transformer主要的缺点有两个，一个是注意力机制相对序列长度的复杂度是O(n^2)，第二个是对位置信息的。\n",
    "第一，Transformer在用Attention机制做序列特征融合的时候，每两个单词之间都要计算点积获得注意力得分，这个计算复杂度和序列的长度平方成正比，对于一些特别长的序列，可能存在着性能瓶颈，有一些针对这个问题的改进方案如Linformer。\n",
    "第二个是Transformer通过引入注意力机制两两位置做点乘来融合序列特征，而不是像循环神经网络那样由先到后地处理序列中的数据，导致丢失了单词之间的位置信息关系，通过在输入中引入正余弦函数构造的位置编码PositionEncoding一定程度上补充了位置信息，但还是不如循环神经网络那样自然和高效。\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f468519a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------------------------------\n",
      "Layer (type)                            Output Shape              Param #\n",
      "==========================================================================\n",
      "MultiheadAttention-1                   [-1, 200, 64]               16,640\n",
      "==========================================================================\n",
      "Total params: 16,640\n",
      "Trainable params: 16,640\n",
      "Non-trainable params: 0\n",
      "--------------------------------------------------------------------------\n",
      "Input size (MB): 0.000000\n",
      "Forward/backward pass size (MB): 0.097656\n",
      "Params size (MB): 0.063477\n",
      "Estimated Total Size (MB): 0.161133\n",
      "--------------------------------------------------------------------------\n",
      "--------------------------------------------------------------------------\n",
      "Layer (type)                            Output Shape              Param #\n",
      "==========================================================================\n",
      "MultiheadAttention-1                   [-1, 200, 64]               16,640\n",
      "==========================================================================\n",
      "Total params: 16,640\n",
      "Trainable params: 16,640\n",
      "Non-trainable params: 0\n",
      "--------------------------------------------------------------------------\n",
      "Input size (MB): 0.000000\n",
      "Forward/backward pass size (MB): 0.097656\n",
      "Params size (MB): 0.063477\n",
      "Estimated Total Size (MB): 0.161133\n",
      "--------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "import torch \n",
    "from torch import nn \n",
    "\n",
    "#验证MultiheadAttention和head数量无关\n",
    "inputs = torch.randn(8,200,64) #batch_size, seq_length, features\n",
    "\n",
    "attention_h8 = nn.MultiheadAttention(\n",
    "    embed_dim = 64,\n",
    "    num_heads = 8,\n",
    "    bias=True,\n",
    "    batch_first=True\n",
    ")\n",
    "\n",
    "attention_h16 = nn.MultiheadAttention(\n",
    "    embed_dim = 64,\n",
    "    num_heads = 16,\n",
    "    bias=True,\n",
    "    batch_first=True\n",
    ")\n",
    "\n",
    "\n",
    "out_h8 = attention_h8(inputs,inputs,inputs)\n",
    "out_h16 = attention_h16(inputs,inputs,inputs)\n",
    "\n",
    "from torchkeras import summary \n",
    "summary(attention_h8,input_data_args=(inputs,inputs,inputs));\n",
    "\n",
    "summary(attention_h16,input_data_args=(inputs,inputs,inputs));\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "df6ad71e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "from torch import nn \n",
    "from copy import deepcopy\n",
    "\n",
    "#多头注意力的一种简洁实现\n",
    "\n",
    "class ScaledDotProductAttention(nn.Module):\n",
    "    \"Compute 'Scaled Dot Product Attention'\"\n",
    "    def __init__(self):\n",
    "        super(ScaledDotProductAttention, self).__init__()\n",
    "\n",
    "    def forward(self,query, key, value, mask=None, dropout=None):\n",
    "        d_k = query.size(-1)\n",
    "        scores = query@key.transpose(-2,-1) / d_k**0.5     \n",
    "        if mask is not None:\n",
    "            scores = scores.masked_fill(mask == 0, -1e20)\n",
    "        p_attn = F.softmax(scores, dim = -1)\n",
    "        if dropout is not None:\n",
    "            p_attn = dropout(p_attn)\n",
    "        return p_attn@value, p_attn\n",
    "    \n",
    "class MultiHeadAttention(nn.Module):\n",
    "    def __init__(self, h, d_model, dropout=0.1):\n",
    "        \"Take in model size and number of heads.\"\n",
    "        super(MultiHeadAttention, self).__init__()\n",
    "        assert d_model % h == 0\n",
    "        # We assume d_v always equals d_k\n",
    "        self.d_k = d_model // h\n",
    "        self.h = h\n",
    "        self.linears = nn.ModuleList([deepcopy(nn.Linear(d_model, d_model)) for _ in range(4)])\n",
    "        \n",
    "        self.attn = None\n",
    "        self.dropout = nn.Dropout(p=dropout)\n",
    "        self.attention = ScaledDotProductAttention()\n",
    "        \n",
    "    def forward(self, query, key, value, mask=None):\n",
    "        \"Implements Figure 2\"\n",
    "        if mask is not None:\n",
    "            # Same mask applied to all h heads.\n",
    "            mask = mask.unsqueeze(1)\n",
    "        nbatches = query.size(0)\n",
    "        \n",
    "        # 1) Do all the linear projections in batch from d_model => h x d_k \n",
    "        query, key, value = \\\n",
    "            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)\n",
    "             for l, x in zip(self.linears, (query, key, value))]\n",
    "        \n",
    "        # 2) Apply attention on all the projected vectors in batch. \n",
    "        x, self.attn = self.attention(query, key, value, mask=mask, \n",
    "                                 dropout=self.dropout)\n",
    "        \n",
    "        # 3) \"Concat\" using a view and apply a final linear. \n",
    "        x = x.transpose(1, 2).contiguous() \\\n",
    "             .view(nbatches, -1, self.h * self.d_k)\n",
    "        return self.linears[-1](x)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4da2d80",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7373a7f0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "58c8c14c",
   "metadata": {},
   "source": [
    "## 五，自定义模型层"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1b9f9b1",
   "metadata": {},
   "source": [
    "如果Pytorch的内置模型层不能够满足需求，我们也可以通过继承nn.Module基类构建自定义的模型层。\n",
    "\n",
    "实际上，pytorch不区分模型和模型层，都是通过继承nn.Module进行构建。\n",
    "\n",
    "因此，我们只要继承nn.Module基类并实现forward方法即可自定义模型层。\n",
    "\n",
    "下面是Pytorch的nn.Linear层的源码，我们可以仿照它来自定义模型层。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "664b53bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class Linear(nn.Module):\n",
    "    __constants__ = ['in_features', 'out_features']\n",
    "\n",
    "    def __init__(self, in_features, out_features, bias=True):\n",
    "        super(Linear, self).__init__()\n",
    "        self.in_features = in_features\n",
    "        self.out_features = out_features\n",
    "        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))\n",
    "        if bias:\n",
    "            self.bias = nn.Parameter(torch.Tensor(out_features))\n",
    "        else:\n",
    "            self.register_parameter('bias', None)\n",
    "        self.reset_parameters()\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))\n",
    "        if self.bias is not None:\n",
    "            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)\n",
    "            bound = 1 / math.sqrt(fan_in)\n",
    "            nn.init.uniform_(self.bias, -bound, bound)\n",
    "\n",
    "    def forward(self, input):\n",
    "        return F.linear(input, self.weight, self.bias)\n",
    "\n",
    "    def extra_repr(self):\n",
    "        return 'in_features={}, out_features={}, bias={}'.format(\n",
    "            self.in_features, self.out_features, self.bias is not None\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57e12a67",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "a5ff24cd",
   "metadata": {},
   "source": [
    "**如果本书对你有所帮助，想鼓励一下作者，记得给本项目加一颗星星star⭐️，并分享给你的朋友们喔😊!** \n",
    "\n",
    "如果对本书内容理解上有需要进一步和作者交流的地方，欢迎在公众号\"算法美食屋\"下留言。作者时间和精力有限，会酌情予以回复。\n",
    "\n",
    "也可以在公众号后台回复关键字：**加群**，加入读者交流群和大家讨论。\n",
    "\n",
    "![算法美食屋logo.png](https://tva1.sinaimg.cn/large/e6c9d24egy1h41m2zugguj20k00b9q46.jpg)"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "formats": "ipynb,md"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
